import numpy as np
import pickle as pkl
import networkx as nx
import scipy.sparse as sp
from scipy.sparse.linalg import eigsh  # the old scipy.sparse.linalg.eigen.arpack path was removed
import sys
import torch
import torch.nn as nn


def parse_skipgram(fname):
    with open(fname) as f:
        toks = list(f.read().split())
    nb_nodes = int(toks[0])
    nb_features = int(toks[1])
    ret = np.empty((nb_nodes, nb_features))
    it = 2
    for i in range(nb_nodes):
        cur_nd = int(toks[it]) - 1
        it += 1
        for j in range(nb_features):
            cur_ft = float(toks[it])
            ret[cur_nd][j] = cur_ft
            it += 1
    return ret
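
# A sketch of the token layout parse_skipgram expects (hypothetical file
# contents; node ids are 1-based in the file and shifted to 0-based above):
#
#   3 2
#   1 0.10 0.20
#   2 0.30 0.40
#   3 0.50 0.60
#
# parse_skipgram(fname) would then return a (3, 2) numpy array.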

# Process (a subset of) a TU dataset into standard form
def process_tu(data, nb_nodes):
    nb_graphs = len(data)
    ft_size = data.num_features

    features = np.zeros((nb_graphs, nb_nodes, ft_size))
    adjacency = np.zeros((nb_graphs, nb_nodes, nb_nodes))
    labels = np.zeros(nb_graphs)
    sizes = np.zeros(nb_graphs, dtype=np.int32)
    masks = np.zeros((nb_graphs, nb_nodes))
       
    for g in range(nb_graphs):
        sizes[g] = data[g].x.shape[0]
        features[g, :sizes[g]] = data[g].x
        labels[g] = data[g].y[0]
        masks[g, :sizes[g]] = 1.0
        e_ind = data[g].edge_index.numpy()  # torch LongTensor -> numpy for scipy
        coo = sp.coo_matrix((np.ones(e_ind.shape[1]), (e_ind[0, :], e_ind[1, :])), shape=(nb_nodes, nb_nodes))
        adjacency[g] = coo.todense()

    return features, adjacency, labels, sizes, masks
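
# Usage sketch, assuming the torch_geometric package is available (MUTAG is
# only an illustrative dataset name, not something this module depends on):
#
#   from torch_geometric.datasets import TUDataset
#   dataset = TUDataset(root='data/TU', name='MUTAG')
#   nb_nodes = max(g.num_nodes for g in dataset)
#   features, adjacency, labels, sizes, masks = process_tu(dataset, nb_nodes)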

def micro_f1(logits, labels):
    # Compute binary predictions by thresholding the sigmoid at 0.5
    preds = torch.round(torch.sigmoid(logits))
    
    # Cast to avoid trouble
    preds = preds.long()
    labels = labels.long()

    # Count true positives, true negatives, false positives, false negatives
    tp = torch.nonzero(preds * labels).shape[0] * 1.0
    tn = torch.nonzero((preds - 1) * (labels - 1)).shape[0] * 1.0
    fp = torch.nonzero(preds * (labels - 1)).shape[0] * 1.0
    fn = torch.nonzero((preds - 1) * labels).shape[0] * 1.0

    # Compute micro-f1 score
    prec = tp / (tp + fp)
    rec = tp / (tp + fn)
    f1 = (2 * prec * rec) / (prec + rec)
    return f1
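
# Quick illustration on toy multi-label tensors (values chosen by hand):
#
#   logits = torch.tensor([[2.0, -1.0], [0.5, -3.0]])
#   labels = torch.tensor([[1, 0], [1, 1]])
#   micro_f1(logits, labels)  # tp=2, fp=0, fn=1 -> prec=1.0, rec=2/3, f1=0.8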

"""
 Prepare adjacency matrix by expanding up to a given neighbourhood.
 This will insert loops on every node.
 Finally, the matrix is converted to bias vectors.
 Expected shape: [graph, nodes, nodes]
"""
def adj_to_bias(adj, sizes, nhood=1):
    nb_graphs = adj.shape[0]
    mt = np.empty(adj.shape)
    for g in range(nb_graphs):
        mt[g] = np.eye(adj.shape[1])
        for _ in range(nhood):
            mt[g] = np.matmul(mt[g], (adj[g] + np.eye(adj.shape[1])))
        for i in range(sizes[g]):
            for j in range(sizes[g]):
                if mt[g][i][j] > 0.0:
                    mt[g][i][j] = 1.0
    return -1e9 * (1.0 - mt)
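
# The returned bias is intended to be added to attention logits before a
# softmax, so positions outside the expanded neighbourhood get near-zero
# weight. A sketch, assuming `logits` is a torch tensor of the same shape:
#
#   bias = adj_to_bias(adj, sizes, nhood=1)          # [graphs, nodes, nodes]
#   attn = torch.softmax(logits + torch.from_numpy(bias).float(), dim=-1)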


###############################################
# This section of code adapted from tkipf/gcn #
###############################################

def parse_index_file(filename):
    """Parse index file."""
    index = []
    with open(filename) as f:  # close the file when done
        for line in f:
            index.append(int(line.strip()))
    return index

def sample_mask(idx, l):
    """Create mask."""
    mask = np.zeros(l)
    mask[idx] = 1
    return np.array(mask, dtype=bool)  # np.bool was removed in NumPy >= 1.24


def load_data():
    with open(r'F:\mypython\final_subject\DWGI\data\adj', 'rb') as f:
        adj = pkl.load(f)

    with open(r'F:\mypython\final_subject\DWGI\data\features', 'rb') as f:
        features = pkl.load(f)

    with open(r'F:\mypython\final_subject\DWGI\data\index', 'rb') as f:
        index = pkl.load(f)

    with open(r'F:\mypython\final_subject\DWGI\data\node_subgraph', 'rb') as f:
        node_subgraph = pkl.load(f)

    with open(r'F:\mypython\final_subject\DWGI\data\subgraph_node', 'rb') as f:
        subgraph_node = pkl.load(f)

    # Get the edges between words and documents
    doc_word_edge = get_doc_word_edge(index)

    return adj, features, index, node_subgraph, subgraph_node, doc_word_edge


def load_data1(dataset_str): # {'pubmed', 'citeseer', 'cora'}
    """Load data."""
    names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
    objects = []
    for name in names:
        with open("data/ind.{}.{}".format(dataset_str, name), 'rb') as f:
            if sys.version_info > (3, 0):
                objects.append(pkl.load(f, encoding='latin1'))
            else:
                objects.append(pkl.load(f))

    x, y, tx, ty, allx, ally, graph = tuple(objects)
    test_idx_reorder = parse_index_file("data/ind.{}.test.index".format(dataset_str))
    test_idx_range = np.sort(test_idx_reorder)

    if dataset_str == 'citeseer':
        # Fix citeseer dataset (there are some isolated nodes in the graph)
        # Find isolated nodes, add them as zero-vecs into the right position
        test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1)
        tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
        tx_extended[test_idx_range-min(test_idx_range), :] = tx
        tx = tx_extended
        ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
        ty_extended[test_idx_range-min(test_idx_range), :] = ty
        ty = ty_extended

    features = sp.vstack((allx, tx)).tolil()
    features[test_idx_reorder, :] = features[test_idx_range, :]
    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))

    labels = np.vstack((ally, ty))
    labels[test_idx_reorder, :] = labels[test_idx_range, :]

    idx_test = test_idx_range.tolist()
    idx_train = range(len(y))
    idx_val = range(len(y), len(y)+500)

    return adj, features, labels, idx_train, idx_val, idx_test
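
# Typical call, assuming the pickled ind.cora.* files from tkipf/gcn sit
# under data/:
#
#   adj, features, labels, idx_train, idx_val, idx_test = load_data1('cora')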

def sparse_to_tuple(sparse_mx, insert_batch=False):
    """Convert sparse matrix to tuple representation.

    Set insert_batch=True if you want to insert a batch dimension.
    """
    def to_tuple(mx):
        if not sp.isspmatrix_coo(mx):
            mx = mx.tocoo()
        if insert_batch:
            coords = np.vstack((np.zeros(mx.row.shape[0]), mx.row, mx.col)).transpose()
            values = mx.data
            shape = (1,) + mx.shape
        else:
            coords = np.vstack((mx.row, mx.col)).transpose()
            values = mx.data
            shape = mx.shape
        return coords, values, shape

    if isinstance(sparse_mx, list):
        for i in range(len(sparse_mx)):
            sparse_mx[i] = to_tuple(sparse_mx[i])
    else:
        sparse_mx = to_tuple(sparse_mx)

    return sparse_mx
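
# Example on a 2x2 identity (coords are (row, col) pairs):
#
#   coords, values, shape = sparse_to_tuple(sp.eye(2).tocoo())
#   # coords == [[0, 0], [1, 1]], values == [1., 1.], shape == (2, 2)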

def standardize_data(f, train_mask):
    """Standardize the feature matrix using training-set statistics."""
    f = f.todense()
    # Drop features with zero variance on the training rows
    sigma = f[train_mask, :].std(axis=0)
    f = f[:, np.squeeze(np.array(sigma > 0))]
    # Standardize the remaining features with the training mean and std
    mu = f[train_mask, :].mean(axis=0)
    sigma = f[train_mask, :].std(axis=0)
    f = (f - mu) / sigma
    return f


def preprocess_features(features):
    """Row-normalize the feature matrix."""
    rowsum = np.array(features.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    features = r_mat_inv.dot(features)
    # return features, sparse_to_tuple(features)
    return features
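
# Row-normalization divides each row by its sum, e.g.:
#
#   X = sp.csr_matrix([[1., 3.], [0., 2.]])
#   preprocess_features(X).toarray()
#   # -> [[0.25, 0.75], [0.  , 1.  ]]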


def normalize_adj(adj):
    """Symmetrically normalize adjacency matrix."""
    adj = sp.coo_matrix(adj)
    rowsum = np.array(adj.sum(1))
    d_inv_sqrt = np.power(rowsum, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()


def preprocess_adj(adj):
    """Preprocessing of adjacency matrix for simple GCN model and conversion to tuple representation."""
    adj_normalized = normalize_adj(adj + sp.eye(adj.shape[0]))
    return sparse_to_tuple(adj_normalized)
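
# preprocess_adj applies the GCN renormalization trick,
#   A_hat = D^{-1/2} (A + I) D^{-1/2},
# e.g. for a single undirected edge every nonzero of A_hat equals 0.5:
#
#   A = sp.csr_matrix([[0., 1.], [1., 0.]])
#   coords, values, shape = preprocess_adj(A)   # values == [0.5, 0.5, 0.5, 0.5]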

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # torch.sparse.FloatTensor is deprecated; use the factory function instead
    return torch.sparse_coo_tensor(indices, values, shape)
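
# Typical use: push a normalized adjacency to the GPU as a sparse tensor
# (a sketch; `device` and `node_features` are assumed to exist elsewhere):
#
#   adj_t = sparse_mx_to_torch_sparse_tensor(normalize_adj(adj)).to(device)
#   out = torch.sparse.mm(adj_t, node_features)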


# Get, for each document, the words it has an edge to:
def get_doc_word_edge(index):
    with open(r'F:\mypython\final_subject\bbdw\include_doc_single_word_out\connected_graph.txt', 'rb') as f:
        graph = pkl.load(f)

    doc_word_dict = {}
    for node in graph.nodes:
        if node.type == 'd':
            doc_index = index[node.name]
            words = []
            for word in node.neighbors:
                if index.get(word) is not None:
                    words.append(index[word])
            doc_word_dict[doc_index] = words
    return doc_word_dict


# From the node-to-subgraph mapping, compute the index ranges of word nodes and
# document nodes, as well as the range of subgraph indices.
# node_subgraph is a dict containing only word nodes; document nodes are
# numbered from 0, and word nodes are numbered after the document nodes.
def get_doc_word_index_range(node_subgraph):
    word_index = sorted(node_subgraph.keys())
    word_index_min = word_index[0]
    word_index_max = word_index[-1]
    doc_index_min = 0
    doc_index_max = word_index_min - 1

    graph_index = sorted(node_subgraph.values())
    graph_index_min = graph_index[0]
    graph_index_max = graph_index[-1]

    return word_index_min, word_index_max, doc_index_min, doc_index_max, graph_index_min, graph_index_max


# Given a word node, compute the embedding of the subgraph it belongs to
def get_graph_emb_by_word(node, node_subgraph, features, subgraph_node, batch_size, device):
    # Index of the subgraph this node belongs to
    node_subgraph_index = node_subgraph[node]

    graph_emb = get_graph_emb_by_graph(node_subgraph_index, subgraph_node, features, batch_size, device)
    return graph_emb


# Get a subgraph's embedding from its subgraph index
# Args: subgraph index, subgraph-to-nodes dict, feature matrix
def get_graph_emb_by_graph(node_subgraph_index, subgraph_node, features, batch_size, device):
    # List of the nodes in this subgraph
    word_list = subgraph_node[node_subgraph_index]
    # Mean-pool the word features (emb_sum avoids shadowing the builtin sum)
    emb_sum = torch.zeros((1, 1, features.shape[2])).to(device)
    for word_index in word_list:
        emb_sum += features[:, word_index, :]
    emb_sum /= len(word_list)
    return emb_sum
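
# The loop above is plain mean pooling over the subgraph's word rows. An
# equivalent vectorized form, assuming a batch dimension of 1 as the shapes
# above suggest:
#
#   emb = features[:, word_list, :].mean(dim=1, keepdim=True)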


# For a document node, obtain one of its subgraphs to use as a positive sample
# Args: document node index, node-to-subgraph dict, doc-to-word edge dict,
#       feature matrix, subgraph-to-nodes dict
def get_graph_emb_by_doc(node, node_subgraph, doc_word_edge, features, subgraph_node, batch_size, device):
    # Get the document's word list (copy it so the shared dict is not mutated)
    neighbor = list(doc_word_edge[node])
    # Shuffle the copy and take the first word
    np.random.shuffle(neighbor)
    choose_word = neighbor[0]
    sub_graph_index = node_subgraph[choose_word]
    emb = get_graph_emb_by_graph(sub_graph_index, subgraph_node, features, batch_size, device)
    return sub_graph_index, emb


def get_sub_graph_embs(graph_min, graph_max, subgraph_node, features):
    graph_num = graph_max - graph_min + 1
    # Rows are indexed by absolute subgraph id here and in the callers, which
    # implicitly assumes graph_min == 0; otherwise an index offset is needed.
    embs = torch.zeros((graph_num, features.shape[1]))
    for i in range(graph_min, graph_max + 1):
        embs[i] = get_graph_emb_by_graph_before_train(i, subgraph_node, features)
    return embs


# Preprocessing: get a subgraph's embedding from its index; differs from the
# function above only in the tensor dimensions involved.
# Args: subgraph index, subgraph-to-nodes dict, feature matrix
def get_graph_emb_by_graph_before_train(node_subgraph_index, subgraph_node, features):
    # List of the nodes in this subgraph
    word_list = subgraph_node[node_subgraph_index]
    # Mean-pool the word features (emb_sum avoids shadowing the builtin sum)
    emb_sum = torch.zeros((1, features.shape[1]))
    for word_index in word_list:
        emb_sum += features[word_index, :]
    emb_sum /= len(word_list)
    return emb_sum


def get_positive_graph_emb_by_doc(node_index, node_subgraph, doc_word_edge, sub_graph_embs, sub_graph_embs_sum, device):
    # Get the document's word list
    neighbor = doc_word_edge[node_index]
    sub_graph_list = []
    for word in neighbor:
        sub_graph_list.append(node_subgraph[word])
    positive_emb = torch.zeros((1, sub_graph_embs.shape[1])).to(device)
    for positive_index in sub_graph_list:
        positive_emb += sub_graph_embs[positive_index]
    positive_emb = positive_emb / len(sub_graph_list)

    negative_emb = sub_graph_embs_sum.clone().to(device)
    for index in sub_graph_list:
        negative_emb -= sub_graph_embs[index]
    negative_emb = negative_emb / (sub_graph_embs.shape[0] - len(sub_graph_list))

    return positive_emb, negative_emb
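
# The negative embedding above reuses a precomputed total: the mean over all
# non-positive subgraphs is (sum_all - sum_positives) / (N - |positives|),
# which avoids looping over every subgraph for each document. Toy check with
# made-up numbers:
#
#   embs = torch.tensor([[1.], [3.], [5.]])
#   total = embs.sum(0)                  # tensor([9.])
#   # positives = {0}: (9 - 1) / (3 - 1) = 4.0 == mean of rows 1 and 2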


# Get the list of subgraphs that a document belongs to
def get_doc_sub_graph(doc_index, doc_word_edge, node_subgraph):
    # Get the document's word list
    neighbor = doc_word_edge[doc_index]
    sub_graph_list = []
    for word in neighbor:
        sub_graph_list.append(node_subgraph[word])
    return sub_graph_list


if __name__ == '__main__':
    with open('../data/node_subgraph', 'rb') as f:
        node_subgraph = pkl.load(f)
    print(get_doc_word_index_range(node_subgraph))


